from datasets import load_dataset

# Module-level partition lists. The ``__main__`` entry point rebinds both
# names to the lists returned by ``split_dataset``. (``test_dateset`` keeps
# its historical spelling because other code may refer to it by that name.)
train_dataset = list()
test_dateset = list()
# WikiMIA subsets to load, in order; each value N selects the Hugging Face
# split ``WikiMIA_length{N}``. ``dataset_to_json`` tags records with these
# values by position.
length = [32, 64, 128, 256]

def split_dataset(split_ratio, seed):
    """Load every WikiMIA length subset and split it into fine-tuning / test parts.

    For each N in the module-level ``length`` list, the split
    ``WikiMIA_length{N}`` of ``swj0419/WikiMIA`` is loaded and divided with
    ``Dataset.train_test_split(test_size=split_ratio, seed=seed)``.

    Note the deliberate inversion: the ``split_ratio`` fraction
    (``split["test"]``) is used for fine-tuning, and the remaining
    ``1 - split_ratio`` fraction (``split["train"]``) is held out for testing.

    Args:
        split_ratio: Fraction of each subset assigned to the fine-tuning part.
        seed: Seed for the shuffle performed by ``train_test_split``; the same
            seed yields the same partition on every run.

    Returns:
        ``(train_parts, test_parts)``: two lists of ``datasets.Dataset``
        objects (fine-tuning parts, test parts), one entry per element of
        ``length`` and in the same order.
    """
    # Fresh lists per call: appending to the module-level lists made repeated
    # calls accumulate duplicate partitions.
    train_parts = []
    test_parts = []
    nums = 0
    for subset_len in length:
        dataset = load_dataset("swj0419/WikiMIA", split=f"WikiMIA_length{subset_len}")
        nums += len(dataset)
        # ``Dataset.shuffle`` returns a new dataset instead of shuffling in
        # place, so calling it without using the result did nothing.
        # ``train_test_split`` shuffles by default; passing ``seed`` makes
        # that shuffle (and therefore the partition) reproducible.
        split = dataset.train_test_split(test_size=split_ratio, seed=seed)
        finetune_part = split["test"]  # smaller part -> fine-tuning
        test_part = split["train"]  # larger part -> testing
        train_parts.append(finetune_part)
        test_parts.append(test_part)
        print("train size: {}, test size: {}".format(len(finetune_part), len(test_part)))
    print(f"Total number of samples: {nums}")
    return train_parts, test_parts  # return the dataset for fine-tuning and testing

def dataset_to_json(dataset, lengths=None):
    """Flatten per-length WikiMIA partitions into JSON-serialisable records.

    Args:
        dataset: Sequence of partitions, one per text length, in the same
            order as ``lengths``. Each partition must support ``["input"]``
            and ``["label"]`` column access returning equally long sequences
            (e.g. a ``datasets.Dataset`` or a plain dict of lists).
        lengths: Length tag for each partition. Defaults to the module-level
            ``length`` list, which matches the order ``split_dataset`` uses.

    Returns:
        List of ``{"text": ..., "label": ..., "length": ...}`` dicts, one per
        example, in partition order.

    Raises:
        ValueError: if there are more partitions than length tags, or if a
            partition's ``input`` and ``label`` columns differ in size.
    """
    if lengths is None:
        lengths = length
    if len(dataset) > len(lengths):
        raise ValueError(
            "got {} partitions but only {} length tags".format(len(dataset), len(lengths))
        )
    all_data = []
    # Keyed by str(label) so both int and str labels "0"/"1" are counted.
    class_map = {"1": 0, "0": 0}
    for part, tag in zip(dataset, lengths):
        texts = part["input"]
        labels = part["label"]
        # Explicit check rather than ``assert``, which is stripped under -O.
        if len(texts) != len(labels):
            raise ValueError(
                "partition for length {} has {} inputs but {} labels".format(
                    tag, len(texts), len(labels)
                )
            )
        for text, label in zip(texts, labels):
            all_data.append({"text": text, "label": label, "length": tag})
            class_map[str(label)] += 1
    print("members: {}, non-members: {}".format(class_map["1"], class_map["0"]))
    return all_data
        
if __name__ == "__main__":
    import json
    import os

    split_ratio = 0.3
    seed = 42
    train_dataset, test_dateset = split_dataset(split_ratio, seed)
    train_data = dataset_to_json(train_dataset)  # dataset for fine-tuning
    test_data = dataset_to_json(test_dateset)  # dataset for testing
    print(f"Number of samples for fine-tuning: {len(train_data)}")
    print(f"Number of samples for testing: {len(test_data)}")

    # Save the dataset. Create the output directory first: open() does not
    # create missing parent directories and would raise FileNotFoundError.
    os.makedirs("data", exist_ok=True)
    with open("data/train_data.json", "w", encoding="utf-8") as f:
        json.dump(train_data, f)
    with open("data/test_data.json", "w", encoding="utf-8") as f:
        json.dump(test_data, f)

    print("Data saved successfully!")
        
    
    